#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os

"""
    本篇attention思想
        将bi-lstm最后时刻输出的h和c进行concat作为query
        将每个时刻双向输出的h作为 key
        将每个时刻双向输出的h作为 value
"""
os.environ["CUDA_VISIBLE_DEVICES"] = '7'

tf.reset_default_graph()

# Bi-LSTM(Attention) Parameters
embedding_dim = 2
n_hidden = 5  # number of hidden units in one cell
n_step = 3  # every sentence consists of 3 words
n_class = 2  # 0 or 1

# 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
# 1 is good, 0 is not good.
labels = [1, 1, 1, 0, 0, 0]

word_list = list(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)

input_batch = []
for sen in sentences:
    input_batch.append(np.asarray([word_dict[n] for n in sen.split()]))

target_batch = []
for label in labels:
    # one-hot encode the label so it matches tf's softmax cross-entropy loss
    target_batch.append(np.eye(n_class)[label])
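# (np.eye(n_class)[label] one-hot encodes, e.g. np.eye(2)[1] -> array([0., 1.]))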

# LSTM Model
X = tf.placeholder(tf.int32, [None, n_step])
# labels must be float for softmax_cross_entropy_with_logits_v2
Y = tf.placeholder(tf.float32, [None, n_class])
out = tf.Variable(tf.random_normal([n_hidden * 2, n_class]))  # output projection weights

embedding = tf.Variable(tf.random_uniform([vocab_size, embedding_dim]))
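# embedding : [vocab_size, embedding_dim] lookup table, trained jointly with the classifier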
# inputs : [batch_size, n_step, embedding_dim]
inputs = tf.nn.embedding_lookup(embedding, X)

lstm_fw_cell = tf.nn.rnn_cell.LSTMCell(n_hidden)
lstm_bw_cell = tf.nn.rnn_cell.LSTMCell(n_hidden)

# output : a tuple (output_fw, output_bw), each of shape [batch_size, n_step, n_hidden]
# final_state : a tuple (output_state_fw, output_state_bw) holding the last forward
#   and backward states. Each is an LSTMStateTuple (c, h), where c is the memory
#   cell and h is the hidden state, each of shape [batch_size, n_hidden].
output, final_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, inputs, dtype=tf.float32)

# Attention
# [batch_size, n_step, n_hidden * 2]
output = tf.concat(output, axis=2)

# final_state[0] is the forward LSTMStateTuple (c, h); concatenating them forms the query
# final_hidden_state : [batch_size, n_hidden * 2]
final_hidden_state = tf.concat(final_state[0], axis=1)

# final_hidden_state : [batch_size, n_hidden * 2, 1]
final_hidden_state = tf.expand_dims(final_hidden_state, 2)

# tf.matmul(output, final_hidden_state) : [batch_size, n_step, 1]
# attn_weights : [batch_size, n_step], one unnormalized score per time step
attn_weights = tf.squeeze(tf.matmul(output, final_hidden_state), axis=2)
# soft_attn_weights : [batch_size, n_step], each row sums to 1
soft_attn_weights = tf.nn.softmax(attn_weights, axis=1)

# context : weighted sum of the values, [batch_size, n_hidden * 2, 1]
context = tf.matmul(tf.transpose(output, [0, 2, 1]), tf.expand_dims(soft_attn_weights, 2))
# context : [batch_size, n_hidden * 2]
context = tf.squeeze(context, 2)

model = tf.matmul(context, out)
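# model : [batch_size, n_class] logits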

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model, labels=Y))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)

# Model-Predict
hypothesis = tf.nn.softmax(model)
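# predictions : [batch_size], the most probable class per example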
predictions = tf.argmax(hypothesis, 1)

# Training
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for epoch in range(5000):
        _, loss, attention = sess.run([optimizer, cost, soft_attn_weights], feed_dict={X: input_batch, Y: target_batch})
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%06d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

    # Test
    test_text = 'sorry hate you'
    tests = [np.asarray([word_dict[n] for n in test_text.split()])]
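    # every test word must already appear in word_dict; an unseen word would
    # raise a KeyError (this toy example has no OOV handling)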

    predict = sess.run([predictions], feed_dict={X: tests})
    result = predict[0][0]
    if result == 0:
        print(test_text, "has a negative meaning...")
    else:
        print(test_text, "has a positive meaning!!")

    fig = plt.figure(figsize=(6, 3))  # attention : [batch_size, n_step]
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')
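    # the leading '' entry below lines the labels up with matshow's automatic ticks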
    ax.set_xticklabels([''] + ['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)
    ax.set_yticklabels([''] + ['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'],
                       fontdict={'fontsize': 14})
    plt.show()
